Actual source code: fsolvebaij.F

  1: !
  2: !
  3: !    Fortran kernel for sparse triangular solve in the BAIJ matrix format
  4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
  5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
  6: !
 7:  #include include/finclude/petscdef.h
  8: !

 10:       subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
 11:       implicit none
 12:       MatScalar   a(0:*)
 13:       PetscScalar x(0:*),b(0:*)
 14:       PetscInt    n,ai(0:*),aj(0:*),adiag(0:*)

 16:       PetscInt    i,j,jstart,jend,idx,ax,jdx
 17:       PetscScalar s1,s2,s3,s4
 18:       PetscScalar x1,x2,x3,x4
 19: !
 20: !     Forward Solve
 21: !

 23:       x(0) = b(0)
 24:       x(1) = b(1)
 25:       x(2) = b(2)
 26:       x(3) = b(3)
 27:       idx  = 0
 28:       do 20 i=1,n-1
 29:          jstart = ai(i)
 30:          jend   = adiag(i) - 1
 31:          ax    = 16*jstart
 32:          idx    = idx + 4
 33:          s1     = b(idx)
 34:          s2     = b(idx+1)
 35:          s3     = b(idx+2)
 36:          s4     = b(idx+3)
 37:          do 30 j=jstart,jend
 38:            jdx   = 4*aj(j)
 39: 
 40:            x1    = x(jdx)
 41:            x2    = x(jdx+1)
 42:            x3    = x(jdx+2)
 43:            x4    = x(jdx+3)
 44:            s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 45:            s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 46:            s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 47:            s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 48:            ax = ax + 16
 49:  30      continue
 50:          x(idx)   = s1
 51:          x(idx+1) = s2
 52:          x(idx+2) = s3
 53:          x(idx+3) = s4
 54:  20   continue
 55: 
 56: !
 57: !     Backward solve the upper triangular
 58: !
 59:       do 40 i=n-1,0,-1
 60:          jstart  = adiag(i) + 1
 61:          jend    = ai(i+1) - 1
 62:          ax     = 16*jstart
 63:          s1      = x(idx)
 64:          s2      = x(idx+1)
 65:          s3      = x(idx+2)
 66:          s4      = x(idx+3)
 67:          do 50 j=jstart,jend
 68:            jdx   = 4*aj(j)
 69:            x1    = x(jdx)
 70:            x2    = x(jdx+1)
 71:            x3    = x(jdx+2)
 72:            x4    = x(jdx+3)
 73:            s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 74:            s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 75:            s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 76:            s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 77:            ax = ax + 16
 78:  50      continue
 79:          ax      = 16*adiag(i)
 80:          x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
 81:          x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
 82:          x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
 83:          x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
 84:          idx      = idx - 4
 85:  40   continue
 86:       return
 87:       end
 88: 
 89: !
 90: !   version that calls BLAS 2 operation for each row block
 91: !
 92:       subroutine FortranSolveBAIJ4BLAS(n,x,ai,aj,adiag,a,b,w)
 93:       implicit none
 94:       MatScalar   a(0:*),w(0:*)
 95:       PetscScalar x(0:*),b(0:*)
 96:       PetscInt  n,ai(0:*),aj(0:*),adiag(0:*)

 98:       PetscInt i,j,jstart,jend,idx,ax,jdx,kdx
 99:       MatScalar   s(0:3)
100: !
101: !     Forward Solve
102: !

104:       x(0) = b(0)
105:       x(1) = b(1)
106:       x(2) = b(2)
107:       x(3) = b(3)
108:       idx  = 0
109:       do 20 i=1,n-1
110: !
111: !        Pack required part of vector into work array
112: !
113:          kdx    = 0
114:          jstart = ai(i)
115:          jend   = adiag(i) - 1
116:          if (jend - jstart .ge. 500) then
117:            write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
118:          endif
119:          do 30 j=jstart,jend
120: 
121:            jdx       = 4*aj(j)
122: 
123:            w(kdx)    = x(jdx)
124:            w(kdx+1)  = x(jdx+1)
125:            w(kdx+2)  = x(jdx+2)
126:            w(kdx+3)  = x(jdx+3)
127:            kdx       = kdx + 4
128:  30      continue

130:          ax      = 16*jstart
131:          idx      = idx + 4
132:          s(0)     = b(idx)
133:          s(1)     = b(idx+1)
134:          s(2)     = b(idx+2)
135:          s(3)     = b(idx+3)
136: !
137: !    s = s - a(ax:)*w
138: !
139:          call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)
140: !         call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)

142:          x(idx)   = s(0)
143:          x(idx+1) = s(1)
144:          x(idx+2) = s(2)
145:          x(idx+3) = s(3)
146:  20   continue
147: 
148: !
149: !     Backward solve the upper triangular
150: !
151:       do 40 i=n-1,0,-1
152:          jstart    = adiag(i) + 1
153:          jend      = ai(i+1) - 1
154:          ax       = 16*jstart
155:          s(0)      = x(idx)
156:          s(1)      = x(idx+1)
157:          s(2)      = x(idx+2)
158:          s(3)      = x(idx+3)
159: !
160: !   Pack each chunk of vector needed
161: !
162:          kdx = 0
163:          if (jend - jstart .ge. 500) then
164:            write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
165:          endif
166:          do 50 j=jstart,jend
167:            jdx      = 4*aj(j)
168:            w(kdx)   = x(jdx)
169:            w(kdx+1) = x(jdx+1)
170:            w(kdx+2) = x(jdx+2)
171:            w(kdx+3) = x(jdx+3)
172:            kdx      = kdx + 4
173:  50      continue
174: !         call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)
175:          call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)

177:          ax      = 16*adiag(i)
178:          x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
179:          x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
180:          x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
181:          x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
182:          idx     = idx - 4
183:  40   continue
184:       return
185:       end
186: 

!
!   FortranSolveBAIJ4 - same triangular solves as FortranSolveBAIJ4BLAS,
!   but the gathered off-diagonal product is done with explicit loops
!   instead of a BLAS call.
!
!   n, x, ai, aj, adiag, a, b - same layout as FortranSolveBAIJ4Unroll
!   w     - work array; receives 4 entries per off-diagonal block of
!           the block row currently being processed
!
!   NOTE(review): the capacity of w is not visible here.  The 500-block
!   threshold below presumably matches the caller's allocation (TODO
!   confirm).  Past it the routine only prints a warning and keeps
!   writing into w.
!
      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar   a(0:*)
      PetscScalar x(0:*),b(0:*),w(0:*)
      PetscInt n,ai(0:*),aj(0:*),adiag(0:*)
      PetscInt ii,jj,i,j

      PetscInt jstart,jend,idx,ax,jdx,kdx,nn
      PetscScalar s(0:3)
!
!     An empty system has nothing to solve.  Without this guard the
!     first block row below would read b(0:3) and write x(0:3).
!
      if (n .le. 0) return
!
!     Forward solve with the unit lower factor.  Block row 0 has no
!     lower blocks, so it is a plain copy.  idx is 4*i throughout.
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      idx  = 0
      do 20 i=1,n-1
!
!        Gather the entries of x touched by the lower blocks of row i
!        into w, 4 consecutive entries per block
!
         kdx    = 0
         jstart = ai(i)
         jend   = adiag(i) - 1
         if (jend - jstart .ge. 500) then
           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
         endif
         do 30 j=jstart,jend
           jdx       = 4*aj(j)
           w(kdx)    = x(jdx)
           w(kdx+1)  = x(jdx+1)
           w(kdx+2)  = x(jdx+2)
           w(kdx+3)  = x(jdx+3)
           kdx       = kdx + 4
 30      continue

         ax       = 16*jstart
         idx      = idx + 4
         s(0)     = b(idx)
         s(1)     = b(idx+1)
         s(2)     = b(idx+2)
         s(3)     = b(idx+3)
!
!        s = s - A*w, with A the 4 x (nn+1) column-major matrix at a(ax).
!        The row index ii is innermost so a is read contiguously.  Each
!        s(ii) still accumulates over jj in the same order as before.
!
         nn = 4*(jend - jstart + 1) - 1
         do 110, jj=0,nn
           do 100, ii=0,3
             s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
 100       continue
 110     continue

         x(idx)   = s(0)
         x(idx+1) = s(1)
         x(idx+2) = s(2)
         x(idx+3) = s(3)
 20   continue
!
!     Backward solve with the upper factor, last block row first.
!     On entry idx = 4*(n-1), left there by the forward sweep.
!
      do 40 i=n-1,0,-1
         jstart    = adiag(i) + 1
         jend      = ai(i+1) - 1
         ax        = 16*jstart
         s(0)      = x(idx)
         s(1)      = x(idx+1)
         s(2)      = x(idx+2)
         s(3)      = x(idx+3)
!
!        Gather the entries of x touched by the upper blocks of row i
!
         kdx = 0
         if (jend - jstart .ge. 500) then
           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
         endif
         do 50 j=jstart,jend
           jdx      = 4*aj(j)
           w(kdx)   = x(jdx)
           w(kdx+1) = x(jdx+1)
           w(kdx+2) = x(jdx+2)
           w(kdx+3) = x(jdx+3)
           kdx      = kdx + 4
 50      continue
!
!        s = s - A*w, same contiguous loop order as the forward sweep
!
         nn = 4*(jend - jstart + 1) - 1
         do 210, jj=0,nn
           do 200, ii=0,3
             s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
 200       continue
 210     continue
!
!        Apply the diagonal block of row i by multiplication
!
         ax      = 16*adiag(i)
         x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
         x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
         x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
         x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
         idx     = idx - 4
 40   continue
      return
      end
294: